import rospy
from frankapy import FrankaArm
from autolab_core import RigidTransform
from sensor_msgs.msg import Image
from autolab_core import RigidTransform, Point

import matplotlib.pyplot as plt
import numpy as np
import utils.transporter_utils as utils


import ipdb
st = ipdb.set_trace  # shorthand breakpoint: call st() to drop into ipdb

# Fixed per-axis correction (x, y, z) added to every commanded end-effector
# position in this file (presumably metres -- confirm against calibration).
OFFSET = [0.02, -0.02, 0]

# 4x4 homogeneous camera extrinsic: the upper-left 3x3 block is the rotation
# and the last column the translation used to build POSE_CAMERA below.
extrinsic1 = np.array([
    [-0.0172734, 0.9989287, 0.0429307, 0.531317],
    [0.9996286, 0.0163485, 0.0218036, -0.011965],
    [0.0210784, 0.0432913, -0.9988401, 0.789953],
    [0, 0, 0, 1],
])

# Camera -> world transform; multiplying a Point in the "camera" frame by it
# yields that point in the "world" frame.
POSE_CAMERA = RigidTransform(
    from_frame="camera",
    to_frame="world",
    rotation=extrinsic1[:3, :3],
    translation=extrinsic1[:3, 3],
)

def imgMsg2cv2(msg):
    """Decode a ``sensor_msgs/Image``-like message into a NumPy array.

    Supported encodings and results:

    * ``bgra8``            -> ``(height, width, 3)`` uint8; the alpha channel is
      dropped and the channel order stays **BGR** (callers flip to RGB).
    * ``mono8``            -> ``(height, width)`` uint8.
    * ``mono16`` / ``16UC1`` -> ``(height, width)`` uint16.
    * ``32FC1``            -> ``(height, width)`` float32.

    Multi-byte element types honour ``msg.is_bigendian`` (treated as little
    endian when the attribute is absent). The result is a read-only view over
    ``msg.data`` produced by ``np.frombuffer``; rows are assumed to be tightly
    packed (``step == width * bytes_per_pixel``).

    Args:
        msg: object with ``encoding``, ``data``, ``height`` and ``width``
            attributes (and optionally ``is_bigendian``).

    Returns:
        The decoded image array.

    Raises:
        AssertionError: if ``msg.encoding`` is not supported (kept as
            AssertionError for compatibility with existing callers).
    """
    # encoding -> (element type, number of channels in the raw buffer)
    layouts = {
        'bgra8': (np.uint8, 4),
        'mono8': (np.uint8, 1),
        '32FC1': (np.float32, 1),
        # 16-bit single-channel images are unsigned integers, not float16.
        'mono16': (np.uint16, 1),
        '16UC1': (np.uint16, 1),
    }
    if msg.encoding not in layouts:
        raise AssertionError(
            'Unsupported image encoding {!r}; supported encodings: {}'.format(
                msg.encoding, ', '.join(sorted(layouts))))

    base, channels = layouts[msg.encoding]
    dtype = np.dtype(base)
    if dtype.itemsize > 1:
        # Interpret multi-byte samples in the byte order declared by the message.
        big = bool(getattr(msg, 'is_bigendian', False))
        dtype = dtype.newbyteorder('>' if big else '<')

    image = np.frombuffer(msg.data, dtype=dtype)
    if channels == 1:
        return image.reshape([msg.height, msg.width])

    image = image.reshape([msg.height, msg.width, channels])
    # BGRA -> BGR: keep the first three channels, drop alpha.
    return image[:, :, :3]

def get_rgbd():
    """Capture one colour frame and one depth frame from the top camera.

    Blocks (via ``rospy.wait_for_message``) first on
    ``/azcam_top/rgb/image_raw`` and then on
    ``/azcam_top/depth_to_rgb/image_raw``. Both images are cropped to the same
    window -- rows ``120:600``, columns ``290:930`` (480 x 640 when the frames
    are at least that large; presumably 720p frames -- confirm) -- and given a
    leading batch axis.

    Returns:
        dict with
        ``"color"``: ``(1, rows, cols, 3)`` array whose channel axis is reversed
        relative to the decoder output (BGR -> RGB for ``bgra8`` input), and
        ``"depth"``: ``(1, rows, cols)`` array.
    """
    # Crop window centred near (360, 640), shared by colour and depth.
    rows = slice(360 - 240, 360 + 240)
    cols = slice(640 - 350, 640 + 290)

    msg = rospy.wait_for_message("/azcam_top/rgb/image_raw", Image)
    rgb = imgMsg2cv2(msg)[rows, cols, :][None]
    rgb = rgb[..., ::-1]  # decoder yields BGR; flip the channel axis to RGB

    msg = rospy.wait_for_message("/azcam_top/depth_to_rgb/image_raw", Image)
    depth = imgMsg2cv2(msg)[rows, cols][None]  # decode once, then crop

    return {
        "color": rgb,
        "depth": depth
    }


def goto_(loc0, pose_camera, fa, offset):
    '''
    Move the end effector to hover above a point given in the camera frame.

    ``loc0`` is wrapped as a ``Point`` in the ``'camera'`` frame and mapped
    through ``pose_camera``. The arm's current pose is read, its orientation is
    kept, and only its translation is replaced by the mapped point
    + ``[0, 0, 0.1]`` + ``offset`` (presumably metres, with z pointing up --
    confirm). The move is issued with ``goto_pose(..., 5,
    ignore_virtual_walls=True)``.

    Args:
        loc0: point coordinates in the camera frame.
        pose_camera: transform applied to the camera-frame point.
        fa: arm handle providing ``get_pose`` and ``goto_pose``.
        offset: per-axis correction added to the target translation.
    '''
    target = pose_camera * Point(loc0, 'camera')
    hover = fa.get_pose()
    hover.translation = target.data + [0.00, 0, 0.1] + offset
    fa.goto_pose(hover, 5, ignore_virtual_walls=True)


def pick_and_place(pose, fa):
    '''
    Pick the object at pose['pose0'] and put it down at pose['pose1'].

    Each entry of ``pose`` is unpacked as a (location, orientation) pair. The
    locations are wrapped as Points in the 'camera' frame, mapped through the
    module-level POSE_CAMERA transform, and shifted by OFFSET.

    Motion sequence:
      open gripper -> hover 0.1 above the pick point -> descend to 0.03 above
      it -> close gripper -> lift by 0.1 -> hover 0.1 above the place point ->
      descend to the recorded grasp height -> goto_gripper(0.055) -> rise to
      0.1 above the place point -> reset_joints() -> open gripper.
    (Distances are presumably metres along the world z axis -- confirm.)

    Args:
        pose: dict with keys 'pose0' (pick) and 'pose1' (place), each a
            (location, orientation) pair. The orientation is presumably an
            XYZW quaternion, since ori1 is passed to quatXYZW_to_eulerXYZ --
            confirm against the caller.
        fa: arm handle providing open_gripper / close_gripper / goto_gripper,
            get_pose / goto_pose and reset_joints.

    NOTE(review): the place yaw is computed from ori1 and printed, but the
    rotation actually applied is z_axis_rotation(np.deg2rad(0)), i.e. none,
    and ori0 is never used. Confirm whether rotation was disabled on purpose.
    '''
    loc0, ori0 = pose['pose0']
    loc1, ori1 = pose['pose1']
    # pick: map the camera-frame pick location into the world frame
    p_camera = loc0
    p_camera = Point(p_camera, 'camera')
    p_base = POSE_CAMERA * p_camera
    fa.open_gripper()
    T_ee_world = fa.get_pose()
    # hover 0.1 above the pick point, keeping the current orientation
    # (2nd goto_pose argument is presumably the motion duration in seconds)
    T_ee_world.translation = p_base.data + [0.00,0,0.1] + OFFSET
    fa.goto_pose(T_ee_world,5,ignore_virtual_walls=True)

    # descend to 0.03 above the pick point and grasp
    T_ee_world.translation = p_base.data + [0.00,0,0.03] + OFFSET
    fa.goto_pose(T_ee_world,3,ignore_virtual_walls=True)
    fa.close_gripper()
    # remember the grasp height; the object is released at this same z
    pick_height = T_ee_world.translation[2]

    # lift straight up by 0.1 from the grasp pose
    T_ee_world.translation += [0,0.00,0.1]
    fa.goto_pose(T_ee_world,2,ignore_virtual_walls=True)

    # place: map the camera-frame place location into the world frame
    p_camera = loc1
    p_camera = Point(p_camera, 'camera')
    p_base = POSE_CAMERA * p_camera
    T_ee_world = fa.get_pose()

    # Third Euler angle of the place orientation (presumably yaw about z),
    # converted to degrees, shifted by 360 once if outside [-180, 180], then
    # clamped to [-165, 165]. It is only printed -- see NOTE(review) above.
    print('Rotation in end-effector frame')
    rad = utils.quatXYZW_to_eulerXYZ(ori1)[2]
    angle = rad*180/np.pi
    if angle > 180:
        angle = angle - 360
    elif angle < -180:
        angle = angle + 360
    if angle > 165:
        angle = 165
    if angle < -165:
        angle = -165
    print(angle)
    # Rotation by 0 rad about the tool z axis: the computed angle is not used.
    T_ee_rot = RigidTransform(
        rotation=RigidTransform.z_axis_rotation(np.deg2rad(0)),
        from_frame='franka_tool', to_frame='franka_tool'
    )
    T_ee_world_target = T_ee_world * T_ee_rot

    # hover 0.1 above the place point
    T_ee_world_target.translation = p_base.data + [0.00,0,0.1] + OFFSET
    fa.goto_pose(T_ee_world_target,4,ignore_virtual_walls=True)

    # descend: the 0.03 z-offset is overwritten by the recorded grasp height
    # (in-place write; assumes .translation returns the stored array, not a copy)
    T_ee_world_target.translation = p_base.data + [0.00,0,0.03] + OFFSET
    T_ee_world_target.translation[2] = pick_height
    fa.goto_pose(T_ee_world_target,4,ignore_virtual_walls=True)

    # release: open the gripper to 0.055 (presumably metres of finger width)
    fa.goto_gripper(0.055)

    # retreat to 0.1 above the place point
    T_ee_world_target.translation = p_base.data + [0.00,0,0.1] + OFFSET
    fa.goto_pose(T_ee_world_target,4,ignore_virtual_walls=True)

    fa.reset_joints()
    fa.open_gripper()


def get_obs(debug):
    """Return an observation dict from either a live capture or cached frames.

    Args:
        debug: when truthy, read the saved snapshot via ``load_img_debug``;
            otherwise capture from the camera via ``get_rgbd``.

    Returns:
        The dict produced by the selected source (``"color"`` and ``"depth"``
        entries).
    """
    source = load_img_debug if debug else get_rgbd
    return source()

def load_img_debug(path='utils/real_robot.npy'):
    """Load a saved RGB-D snapshot and crop it to the live-capture window.

    The file is expected to contain a dict with ``"img"`` and ``"depth"``
    arrays, saved with ``np.save`` (stored as a 0-d object array). Both are
    cropped to rows ``120:600`` and columns ``290:930``, the same window
    ``get_rgbd`` uses, and given a leading batch axis. No channel reordering is
    applied.

    Args:
        path: location of the ``.npy`` file. Relative paths resolve against the
            current working directory; the default is the historical location.

    Returns:
        dict with ``"color"`` (cropped ``img`` with shape ``(1, rows, cols, C)``)
        and ``"depth"`` (cropped ``depth`` with shape ``(1, rows, cols)``).

    Raises:
        KeyError: if the saved dict lacks ``"img"`` or ``"depth"``.
    """
    # NOTE: allow_pickle=True unpickles arbitrary objects -- only load
    # trusted, locally produced snapshot files.
    saved = np.load(path, allow_pickle=True)
    # Unwrap the 0-d object array with .item(). Do not use .any(): on
    # NumPy >= 2 it returns a bool for object arrays instead of the dict.
    data = saved.item()
    img = data["img"]
    depth = data["depth"]

    rows = slice(360 - 240, 360 + 240)
    cols = slice(640 - 350, 640 + 290)
    return {
        "color": img[rows, cols, :][None],
        "depth": depth[rows, cols][None],
    }

if __name__ == "__main__":
    # Offline smoke test: load and crop the cached snapshot from disk.
    # Nothing here connects to the robot or the camera.
    load_img_debug()